#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@Author ：hhx
@Date ：2022/5/19 21:33
@Description ：AE (autoencoder) training script, simplest version
"""
import numpy as np
import os
from utils import *
import torch
from torch import nn, optim
from torch.utils import data
from models import AE, AE_withLinear, UNet
from tqdm import tqdm
from PIL import Image
# import os
# os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
# Device used for the model and every batch. Switch to the line below to pick
# up a GPU when one is available.
device = 'cpu'
# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Module-level side effect: make float64 the default floating-point dtype, so
# the model's parameters are created as double tensors. This replaces the
# deprecated torch.set_default_tensor_type(torch.DoubleTensor) with the same
# effect on CPU.
torch.set_default_dtype(torch.float64)


def train_one_epoch(model, batches, criterion, optimizer, device='cpu'):
    """Run one optimisation pass over ``batches`` and return the mean batch loss.

    Args:
        model: network to train; it is switched to training mode here.
        batches: iterable of ``(inputs, targets)`` pairs, e.g. a DataLoader
            (optionally wrapped in a progress bar).
        criterion: loss callable ``criterion(outputs, targets)`` returning a
            scalar tensor.
        optimizer: optimizer over ``model``'s parameters.
        device: device each batch is moved to before the forward pass.

    Returns:
        The mean of the per-batch loss values as a Python float, or 0.0 when
        ``batches`` yields nothing.
    """
    model.train()
    total = 0.0
    n_batches = 0
    for inputs, targets in batches:
        inputs = inputs.to(device)
        targets = targets.to(device)
        outputs = model(inputs)
        batch_loss = criterion(outputs, targets)
        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()
        # .item() extracts a plain number; summing the tensor itself would keep
        # every batch's autograd graph alive until the end of the epoch.
        total += batch_loss.item()
        n_batches += 1
    return total / n_batches if n_batches else 0.0


def main():
    """Train ``AE_withLinear.Autoencoder`` on the TIFF dataset and save its weights.

    Batches of ``(images, labels)`` come from ``CarTiffDateSet`` under
    ``dataset/TIF``. The MSE between model output and label is minimised with
    Adam (weight decay 1e-6), the mean batch loss is printed after every epoch,
    and the final ``state_dict`` is written to ``SavedModels/AE_withLinear.pkl``.
    """
    batch = 8
    # Built with os.path.join: a literal 'dataset\TIF' holds the invalid escape
    # "\T" and a backslash separator is only understood on Windows.
    datasetpath = os.path.join('dataset', 'TIF')
    # Alternative local dataset location:
    # datasetpath = r'D:\日常\根据波段指数变化检测'
    train_set = CarTiffDateSet(datasetpath)
    train_loader = torch.utils.data.DataLoader(dataset=train_set,
                                               batch_size=batch,
                                               shuffle=True)

    model = AE_withLinear.Autoencoder().to(device)
    criterion = nn.MSELoss().to(device)
    optimizer = optim.Adam(model.parameters(), weight_decay=1e-6)
    # range(1, 30) runs 29 epochs, numbered 1..29.
    for epoch in range(1, 30):
        mean_loss = train_one_epoch(model, tqdm(train_loader), criterion,
                                    optimizer, device)
        print(f'epoch {epoch}: mean loss {mean_loss:.6f}')

    # Make sure the output directory exists so a finished training run is not
    # lost to a FileNotFoundError at save time.
    os.makedirs('SavedModels', exist_ok=True)
    # torch.save(model.state_dict(), 'SavedModels/UNet.pkl')
    torch.save(model.state_dict(), 'SavedModels/AE_withLinear.pkl')


if __name__ == '__main__':
    main()
